from pytdx.hq import TdxHq_API as hq
from pytdx.exhq import TdxExHq_API as exhq
import sys,os
def get_market(code):
    """Normalise a security code and guess which market it belongs to.

    The code is converted to text, cut down to its last six characters and
    left-padded with zeros to six characters.

    Returns:
        tuple: ``(market, code6)`` where ``market`` is 1 (Shanghai) when the
        normalised code starts with ``6``, ``9``, ``11``, ``13`` or ``5``,
        and 0 otherwise.
    """
    code6 = str(code)[-6:].zfill(6)
    sh_prefixes = ('6', '9', '11', '13', '5')
    market = 1 if code6.startswith(sh_prefixes) else 0
    return market, code6

class mytdx:
    def __init__(self, tdxhq_host, tdxhq_port, auto_connect=True,
                 tdxexhq_host=None, tdxexhq_port=None):
        """Create the quote client(s) and optionally connect them.

        Args:
            tdxhq_host: Host of the standard quote server.
            tdxhq_port: Port of the standard quote server.
            auto_connect: When truthy, connect the standard client right
                away; the process is terminated if that connection fails.
            tdxexhq_host: Host of the extended quote server (optional).
            tdxexhq_port: Port of the extended quote server (optional).
                The extended client is only created and connected when both
                host and port are given; the process is terminated if that
                connection fails.

        Attributes set:
            host, port: The standard server endpoint.
            _hq: The standard quote client.
            _exhq: The extended quote client, or ``None`` when not requested.
            get_market: The module-level ``get_market`` helper.
        """
        self.host = tdxhq_host
        self.port = tdxhq_port
        # NOTE: raise_exception=True is left disabled, as before.
        self._hq = hq(heartbeat=True, auto_retry=True)
        # Always define the attribute so later code can test it for None
        # instead of hitting AttributeError.
        self._exhq = None

        if auto_connect:
            if not self.conn():
                # os._exit() skips stdio flushing, so flush explicitly or the
                # message can be lost when stdout is buffered.
                print('failed tdxhq conn', flush=True)
                # NOTE(review): exits with status 0 on a failure path; the
                # original marks this as temporary.
                os._exit(0) #tmp

        if tdxexhq_host and tdxexhq_port:
            # NOTE: multithread=True / raise_exception=True are left disabled.
            _exhq = exhq(heartbeat=True, auto_retry=True)
            if _exhq and _exhq.connect(tdxexhq_host, tdxexhq_port):
                print('ok tdxexhq conn', tdxexhq_host, tdxexhq_port)
            else:
                # The undefined name `_ex_hq` was referenced here before,
                # which raised NameError instead of reporting the failure.
                print('fail tdxexhq conn', type(_exhq), tdxexhq_host,
                      tdxexhq_port, flush=True)
                os._exit(0) #tmp
            self._exhq = _exhq

        self.get_market = get_market

    def conn(self):
        """Connect the standard quote client to the stored host and port.

        Returns:
            Whatever the client's ``connect`` call returns.
        """
        endpoint = (self.host, self.port)
        return self._hq.connect(*endpoint)
    def disconn(self):
        """Disconnect the standard quote client.

        The pytdx client's ``disconnect()`` takes no arguments (it closes the
        socket it already holds), so the host and port are not passed;
        passing them raised ``TypeError``.

        Returns:
            Whatever the client's ``disconnect`` call returns.
        """
        return self._hq.disconnect()

    def get_tick_yyyy(self, code, yyyy, freq='1min'):
        """Download historical transaction (tick) records for one year.

        Every calendar day from ``{yyyy}0101`` up to (not including) the end
        date is requested page by page until the server returns nothing.

        Args:
            code: Security code. A trailing ``SZ`` / ``SH`` selects market
                0 / 1 and the first six characters are used as the code;
                anything else is resolved through ``get_market``. Integers
                are accepted and converted to text.
            yyyy: Year to download (anything ``int()`` accepts).
            freq: Currently unused.

        Returns:
            pandas.DataFrame: One row per tick with positional columns
            0..3 = (time value, price, vol, buyorsell), sorted by column 0.
            The time value is ``1000 * time_maker(...)`` for the tick's
            date and minute (presumably epoch milliseconds — TODO confirm
            against ``time_maker``). An empty frame with the same four
            columns is returned when no tick was found.
        """
        from mypy import time_maker
        import pandas as pd

        get_history_transaction_data = self._hq.get_history_transaction_data

        code = str(code)  # accept int codes, like get_market() does
        suffix = code[-2:]
        if suffix == 'SZ':
            mk2, code6 = 0, code[:6]
        elif suffix == 'SH':
            mk2, code6 = 1, code[:6]
        else:
            mk2, code6 = get_market(code)

        i_date = f'{yyyy}0101'
        end_date = '{}0101'.format(1 + int(yyyy))
        # Cap the range at the date time_maker(1, ...) yields (presumably
        # tomorrow — TODO confirm). Dates are zero-padded YYYYMMDD strings,
        # so plain string comparison orders them correctly.
        end1_date = time_maker(1, outfmt='%Y%m%d')
        if end_date > end1_date:
            end_date = end1_date

        page_size = 2000 # fix by tdx protocol
        rsa = []
        while i_date < end_date:
            print('down', i_date, mk2, code6)
            i = 0
            while True:
                a = get_history_transaction_data(mk2, code6, i, page_size, int(i_date))
                if a is None or len(a) == 0:
                    break
                # NOTE(review): pages are appended in request order. If
                # offset 0 is the most recent page (TODO confirm), ticks of a
                # minute that straddles a page boundary end up out of order.
                rsa.extend(
                    (
                        1000 * time_maker(date='{} {}:00'.format(i_date, v['time']),
                                          infmt='%Y%m%d %H:%M:%S', outfmt=''),
                        v['price'],
                        v['vol'],
                        v['buyorsell'],
                    )
                    for v in a
                )
                i += page_size
            i_date = time_maker(1, date=i_date, infmt='%Y%m%d')

        # Naming the columns keeps an empty result sortable: a frame built
        # from [] has no column 0 and sort_values(0) would raise KeyError.
        df = pd.DataFrame(rsa, columns=range(4))
        # Many ticks share the same minute-resolution key; a stable sort keeps
        # their collected order instead of shuffling them arbitrarily.
        return df.sort_values(0, kind='mergesort')
        
    def get_df_yyyy(self, code, yyyy, freq='1min'):
        #TODO from mypy import parallel
        rsa = []
        get_security_bars = self._hq.get_security_bars
        if freq=='1min':
            pass
        else:
            assert False, f'TODO freq {freq}'
        # TODO: pytdx这里远程获取分钟OHLCVA不太顺利，估计需要用reader读取本地文件

#        for a in get_history_minute_time_data(market, code6, yyyymmdd)
#.get_security_bars(7,0,'000001',0,800)
#        get_history_minute_time_data = self._hq.get_history_minute_time_data
#tdx._hq.get_history_minute_time_data(0,'000001',20220104)

#GetHistoryMinuteTimeData(nMarket, sStockCode, nDate)
# notes: b[i]['price']
#GetTransactionData 获取当日指定范围的分笔成交数据 
#GetTransactionData(nMarket, sStockCode, nStart, nCount) 
#GetXDXRInfo 获取除权除息数据 
#GetXDXRInfo(nMarket, sStockCode) 
#GetSecurityBars 获取某个范围内的证券K 线数据 
#GetSecurityBars(nCategory, nMarket, sStockCode, nStart, nCount) 
#获取市场内指定范围的证券K 线， 
#指定开始位置和指定K 线数量，指定数量最大值为800。 
#参数： 
#nCategory -> K 线种类 
#0 5 分钟K 线 
#1 15 分钟K 线 
#2 30 分钟K 线 
#3 1 小时K 线 
#4 日K 线 
#5 周K 线 
#6 月K 线 
#7 1 分钟 
#8 1 分钟K 线 
#9 日K 线 
#10 季K 线 
#11 年K 线 
#nMarket -> 市场代码0:深圳，1:上海 
#sStockCode -> 证券代码； 
#nStart -> 指定的范围开始位置； 
#nCount -> 用户要请求的K 线数目，最大值为800。
# e.g. tdx._hq.get_security_bars(7,0,'000001',0,800)

    def get_security_quotes_80(self, code_a):
        """Request quotes for the given codes in a single call.

        Each code is turned into a ``(market, code6)`` pair with
        ``get_market`` and the list of pairs is passed to the client's
        ``get_security_quotes``. The batch size is not checked here.

        Returns:
            Whatever ``get_security_quotes`` returns.
        """
        market_code_pairs = [get_market(item) for item in code_a]
        quotes = self._hq.get_security_quotes(market_code_pairs)
        return quotes
    
    def get_market_records_80(self, code_a):
        """Fetch quotes for one batch of codes and return them as dicts.

        Returns:
            list[dict]: One dict per quote with the keys ``code``,
            ``market``, ``last_close`` and ``price``. An empty list is
            returned when the quote call yields ``None`` or an empty frame.
            A message is printed when rows existed but none was kept.
        """
        quotes = self.get_security_quotes_80(code_a)
        if quotes is None:
            return []
        frame = self._hq.to_df(quotes)
        if frame.empty:
            return []

        wanted = ['code', 'market', 'last_close', 'price']  # 'servertime' is not selected
        records = []
        for rec in frame[wanted].to_dict(orient='records'):
            rec_code = rec.get('code', None)
            if rec_code is None:
                continue
            # Temporary hack carried over from the original: prices of codes
            # starting with '1' or '5' are divided by 10 (still a TODO).
            if rec_code.startswith(('1', '5')):
                for field in ('price', 'last_close'):
                    rec[field] = round(float(rec[field] / 10), 3)
            records.append(rec)

        if len(records) == 0:
            print('!!! get_market_records_80 empty for', code_a)
        return records

    def get_market_records(self, _code_a):
        """Fetch quote records for any number of codes, 80 codes per request.

        Codes whose normalised six-character form starts with ``7`` are not
        supported yet; they are reported and skipped. The remaining codes are
        sent to ``get_market_records_80`` in consecutive batches of at most
        80, and a warning is printed for every batch that yields no record.

        Returns:
            list: The records of all batches, concatenated in batch order.
        """
        batch_size = 80
        supported = []
        for raw in _code_a:
            if str(raw)[-6:].zfill(6).startswith('7'):  # not yet supported
                print('skip code', raw)
                continue
            supported.append(raw)

        records = []
        for start in range(0, len(supported), batch_size):
            batch = supported[start:start + batch_size]
            batch_records = self.get_market_records_80(batch)
            if len(batch_records) == 0:
                print('WARNING', batch, batch_records)
            records += batch_records
        return records

def test():
    """Create a ``mytdx`` client for the default quote server and return it."""
    # Other hosts noted as alternatives:
    #   47.103.48.45, 120.79.60.82, sztdx.gtjas.com
    server = ('139.159.143.228', 7709)  # gz gbp
    return mytdx(*server)

def testex():
    """Create a ``mytdx`` client that also uses an extended quote server.

    The standard client is asked to connect immediately
    (``auto_connect=True``) and the extended endpoint is passed along.
    """
    # Other extended-quote hosts noted as alternatives:
    #   106.14.95.149, 112.74.214.43, 120.25.218.6, 47.107.75.159,
    #   119.97.185.5, 59.175.238.38, 47.102.108.214
    hq_host, hq_port = '139.159.143.228', 7709  # gz gbp
    exhq_host, exhq_port = '112.74.214.43', 7727
    return mytdx(
        hq_host,
        hq_port,
        auto_connect=True,
        tdxexhq_host=exhq_host,
        tdxexhq_port=exhq_port,
    )
if __name__ == '__main__':
    # Manual smoke run: build the default client and print the records for a
    # handful of codes.
    tdx = test()
    print(
        tdx.get_market_records(
            [999999, 399001, 399006, 399300, 399905, 725]
        )
    )
"""
TODO for QMT...
[(v[0],v[1]['lastPrice'],v[1]['lastClose'],v[1]['timetag']) for v in g_ctx.get_full_tick(['000001.SH','399001.SZ','399006.SZ','399300.SZ','399905.SZ']).items()]
"""
